{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fit Functions\n",
    "\n",
    "[Fit functions]: ../../api_static/plasmapy.analysis.fit_functions.rst\n",
    "[linregress()]: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.linregress.html\n",
    "[curve_fit()]: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html\n",
    "[fsolve()]: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.fsolve.html\n",
    "\n",
    "[Fit functions] are a set of callable classes designed to aid in fitting analytical functions to data.  A fit function class combines the following functionality:\n",
    "\n",
    "1. An analytical function that is callable with given parameters or fitted parameters.\n",
    "1. Curve fitting functionality (usually SciPy's [curve_fit()] or [linregress()]), which stores the fit statistics and parameters in the class instance.  This makes the function easily callable with the fitted parameters.\n",
    "1. Error propagation calculations.\n",
    "1. A root solver that either returns the known analytical solutions or uses SciPy's [fsolve()] to calculate the roots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from plasmapy.analysis import fit_functions as ffuncs\n",
    "\n",
    "plt.rcParams[\"figure.figsize\"] = [10.5, 0.56 * 10.5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents:\n",
    "\n",
    "1. [Fit function basics](#Fit-function-basics)\n",
    "1. [Fitting to data](#Fitting-to-data)\n",
    "    1. [Getting fit results](#Getting-fit-results)\n",
    "    1. [Fit function is callable](#Fit-function-is-callable)\n",
    "    1. [Plotting results](#Plotting-results)\n",
    "    1. [Root solving](#Root-solving)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fit function basics\n",
    "\n",
    "[fit functions]: ../../api_static/plasmapy.analysis.fit_functions.rst\n",
    "[ExponentialPlusLinear]: ../../api/plasmapy.analysis.fit_functions.ExponentialPlusLinear.rst\n",
    "\n",
    "There is an ever-expanding collection of [fit functions], but this notebook will use [ExponentialPlusLinear] as an example.\n",
    "\n",
    "A fit function class has no required arguments at instantiation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# basic instantiation\n",
    "explin = ffuncs.ExponentialPlusLinear()\n",
    "\n",
    "# fit parameters are not set yet\n",
    "(explin.params, explin.param_errors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each fit parameter is given a name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "explin.param_names"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These names are used throughout the [fit function's documentation](../../api/plasmapy.analysis.fit_functions.ExponentialPlusLinear.rst), as well as in its `__repr__` and `__str__` methods and its `latex_str` property."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(explin, str(explin), explin.latex_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fitting to data\n",
    "\n",
    "[curve_fit()]: ../../api/plasmapy.analysis.fit_functions.ExponentialPlusLinear.rst#plasmapy.analysis.fit_functions.ExponentialPlusLinear.curve_fit\n",
    "[linregress()]: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.linregress.html\n",
    "[Linear]: ../../api/plasmapy.analysis.fit_functions.Linear.rst#plasmapy.analysis.fit_functions.Linear.curve_fit\n",
    "\n",
    "Fit functions provide the [curve_fit()] method to fit the analytical function to a set of $(x, y)$ data.  Fitting is typically done with SciPy's [curve_fit()](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html) function, but the [Linear] fit function uses SciPy's [linregress()] instead."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's generate some noisy data to fit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = (5.0, 0.1, -0.5, -8.0)  # (a, alpha, m, b)\n",
    "xdata = np.linspace(-20, 15, num=100)\n",
    "ydata = explin.func(xdata, *params) + np.random.normal(0.0, 0.6, xdata.size)\n",
    "\n",
    "plt.plot(xdata, ydata)\n",
    "plt.xlabel(\"X\", fontsize=14)\n",
    "plt.ylabel(\"Y\", fontsize=14)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[curve_fit()]: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html\n",
    "\n",
    "The fit function's `curve_fit()` method shares its signature with SciPy's [curve_fit()], so any `**kwargs` are passed on to it.  By default, only the $(x, y)$ values are needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "explin.curve_fit(xdata, ydata)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Getting fit results\n",
    "\n",
    "After fitting, the fitted parameters, their uncertainties, and the [coefficient of determination](https://en.wikipedia.org/wiki/Coefficient_of_determination), or $r^2$, can be retrieved through the respective properties `params`, `param_errors`, and `rsq`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(explin.params, explin.params.a, explin.params.alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(explin.param_errors, explin.param_errors.a, explin.param_errors.alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "explin.rsq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fit function is callable\n",
    "\n",
    "Now that parameters are set, the fit function is callable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "explin(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The associated uncertainties can also be returned by setting `reterr=True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y, y_err = explin(np.linspace(-1, 1, num=10), reterr=True)\n",
    "(y, y_err)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Known uncertainties in $x$ can also be specified with `x_err`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y, y_err = explin(np.linspace(-1, 1, num=10), reterr=True, x_err=0.1)\n",
    "(y, y_err)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plotting results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx-thumbnail": {
     "tooltip": "Fit Function Tutorial"
    }
   },
   "outputs": [],
   "source": [
    "# plot original data\n",
    "plt.plot(xdata, ydata, marker=\"o\", linestyle=\" \", label=\"Data\")\n",
    "ax = plt.gca()\n",
    "ax.set_xlabel(\"X\", fontsize=14)\n",
    "ax.set_ylabel(\"Y\", fontsize=14)\n",
    "\n",
    "ax.axhline(0.0, color=\"r\", linestyle=\"--\")\n",
    "\n",
    "# plot fitted curve + error\n",
    "yfit, yfit_err = explin(xdata, reterr=True)\n",
    "ax.plot(xdata, yfit, color=\"orange\", label=\"Fit\")\n",
    "ax.fill_between(\n",
    "    xdata,\n",
    "    yfit + yfit_err,\n",
    "    yfit - yfit_err,\n",
    "    color=\"orange\",\n",
    "    alpha=0.12,\n",
    "    zorder=0,\n",
    "    label=\"Fit Error\",\n",
    ")\n",
    "\n",
    "# plot annotations\n",
    "plt.legend(fontsize=14, loc=\"upper left\")\n",
    "\n",
    "txt = f\"$f(x) = {explin.latex_str}$\\n$r^2 = {explin.rsq:.3f}$\\n\"\n",
    "for name, param, err in zip(\n",
    "    explin.param_names, explin.params, explin.param_errors, strict=False\n",
    "):\n",
    "    txt += f\"{name} = {param:.3f} $\\\\pm$ {err:.3f}\\n\"\n",
    "txt_loc = [-13.0, ax.get_ylim()[1]]\n",
    "txt_loc = ax.transAxes.inverted().transform(ax.transData.transform(txt_loc))\n",
    "txt_loc[0] -= 0.02\n",
    "txt_loc[1] -= 0.05\n",
    "ax.text(\n",
    "    txt_loc[0],\n",
    "    txt_loc[1],\n",
    "    txt,\n",
    "    fontsize=\"large\",\n",
    "    transform=ax.transAxes,\n",
    "    va=\"top\",\n",
    "    linespacing=1.5,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Root solving\n",
    "\n",
    "[fsolve()]: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.fsolve.html\n",
    "[ExponentialPlusLinear().root_solve()]: ../../api/plasmapy.analysis.fit_functions.ExponentialPlusLinear.rst#plasmapy.analysis.fit_functions.ExponentialPlusLinear.root_solve\n",
    "[Linear().root_solve()]: ../../api/plasmapy.analysis.fit_functions.Linear.rst#plasmapy.analysis.fit_functions.Linear.root_solve\n",
    "\n",
    "An exponential plus a linear offset has no analytical solution for its roots, except in a few specific cases.  To get around this, [ExponentialPlusLinear().root_solve()] uses SciPy's [fsolve()] to calculate its roots.  If a fit function has analytical solutions for its roots (e.g. [Linear().root_solve()]), then the method is overridden with the known solution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# solve for the root, using -15.0 as the initial guess\n",
    "root, err = explin.root_solve(-15.0)\n",
    "(root, err)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Linear().root_solve()]: ../../api/plasmapy.analysis.fit_functions.Linear.rst#plasmapy.analysis.fit_functions.Linear.root_solve\n",
    "\n",
    "Let's use [Linear().root_solve()] as an example for a known solution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lin = ffuncs.Linear(params=(1.0, -5.0), param_errors=(0.1, 0.1))\n",
    "root, err = lin.root_solve()\n",
    "(root, err)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
